{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "dataset = np.array([[1, 3, 4],\n",
    "                    [2, 3, 5],\n",
    "                    [1, 2, 3, 5],\n",
    "                    [2, 5]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 频繁项集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_1st_itemset(dataset):\n",
    "    '''\n",
    "    根据初始数据集创建单个物品的项集\n",
    "    '''\n",
    "#     tmp=dataset.flatten(dataset)    # 当输入数据规整时可使用numpy\n",
    "#     return list(map(frozenset,np.unique(tmp)))\n",
    "\n",
    "    tmp = list()\n",
    "    for itemset in dataset:\n",
    "        for item in itemset:\n",
    "            if [item] not in tmp:\n",
    "                tmp.append([item])\n",
    "    tmp.sort()\n",
    "    return list(map(frozenset, tmp))    # frozenset可用作字典key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def itemset_filter(dataset, itemsets, min_sup=0.5):\n",
    "    '''\n",
    "    过滤小于支持度阈值的项集，返回频繁项集与支持度\n",
    "    itemsets: 项集列表\n",
    "    '''\n",
    "    sup_dict = dict()    # 支持度字典\n",
    "\n",
    "    # 先计数\n",
    "    for sample in dataset:    # 原数据中的项集\n",
    "        for itemset in itemsets:    # 项集列表中的所有项集\n",
    "            if itemset.issubset(sample):\n",
    "                sup_dict[itemset] = sup_dict.get(itemset, 0)+1\n",
    "\n",
    "    len_data = len(dataset)    # 计数/总数=支持度\n",
    "    freq_sets = list()    # 频繁项集列表\n",
    "\n",
    "    # 计算支持度并过滤掉不合格的项集\n",
    "    for itemset in sup_dict:\n",
    "        sup = sup_dict[itemset]/len_data\n",
    "\n",
    "        if sup >= min_sup:\n",
    "            freq_sets.append(itemset)\n",
    "        sup_dict[itemset] = sup\n",
    "\n",
    "    return freq_sets, sup_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extend_itemset(freq_sets):\n",
    "    '''\n",
    "    根据已有的k阶频繁项集生成k+1阶项集\n",
    "    freq_sets: k阶频繁项集的列表\n",
    "    '''\n",
    "    res = list()\n",
    "    raw_size = len(freq_sets)\n",
    "\n",
    "    for i in range(raw_size):\n",
    "        for j in range(i+1, raw_size):    # 两两组合\n",
    "            # 当两项集头部都相等时，才进行合并\n",
    "            head_1 = list(freq_sets[i])[:-1]\n",
    "            head_2 = list(freq_sets[j])[:-1]\n",
    "            if head_1 == head_2:\n",
    "                res.append(freq_sets[i] | freq_sets[j])\n",
    "    return res\n",
    "\n",
    "\n",
    "# itemsets_1st= create_1st_itemset(dataset)\n",
    "# freq_sets_1st,_=itemset_filter(dataset, itemsets_1st, 0.5)\n",
    "# itemsets_2nd = extend_itemset(freq_sets_1st)\n",
    "# freq_sets_2nd, _ = itemset_filter(dataset, itemsets_2nd, 0.5)\n",
    "# itemsets_3th = extend_itemset(freq_sets_2nd)\n",
    "# freq_sets_3th, _ = itemset_filter(dataset, itemsets_3th, 0.5)\n",
    "# freq_sets_3th"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apriori(dataset, min_sup=0.5):\n",
    "    itemsets_1st = create_1st_itemset(dataset)    # 构建1阶项集列表\n",
    "    freq_sets_1st, sup_dict = itemset_filter(dataset, itemsets_1st, min_sup)\n",
    "\n",
    "    all_freq_sets = [freq_sets_1st]    # 所有的频繁项集，以阶数为分隔\n",
    "\n",
    "    while len(all_freq_sets[-1]) > 0:    # 不断寻找高阶频繁集，直到产生空集\n",
    "        cur_itemsets = extend_itemset(all_freq_sets[-1])    # 对最后一阶频繁集升阶，产生高阶项集\n",
    "        cur_freq_sets, cur_sup_dict = itemset_filter(\n",
    "            dataset, cur_itemsets, min_sup)\n",
    "        sup_dict.update(cur_sup_dict)\n",
    "        all_freq_sets.append(cur_freq_sets)\n",
    "\n",
    "    return all_freq_sets, sup_dict\n",
    "\n",
    "\n",
    "# all_freq_sets, sup_dict = apriori(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_freq_sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sup_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 关联规则\n",
    "有了能计算频繁项集的方法，接下来就是在频繁项集中找到关联规则。为简单起见，首先计算$2$阶频繁项集可能产生的所有规则。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_conf(freq_set, conseqs, sup_dict, min_conf=0.5):\n",
    "    '''\n",
    "    计算频繁项集中所有单后果规则的置信度\n",
    "    freq_set: 单个频繁项集\n",
    "    conseqs: 可能的后果列表\n",
    "    '''\n",
    "    strong_conseqs = list()    # 强后果列表，返回用于扩展高阶后果\n",
    "    strong_rules = list()    # 规则列表\n",
    "\n",
    "    for cons in conseqs:    # 计算所有可能后果的规则强度\n",
    "        cur_conf = sup_dict[freq_set]/sup_dict[freq_set-cons]    # 共同置信度除以前因置信度\n",
    "\n",
    "        if cur_conf >= min_conf:\n",
    "            # 追加(前因,后果,置信度)三元组\n",
    "            strong_rules.append((freq_set-cons, cons, cur_conf))\n",
    "            strong_conseqs.append(cons)\n",
    "\n",
    "    return strong_rules, strong_conseqs\n",
    "\n",
    "\n",
    "# freq_set = all_freq_sets[1][0]    # 2阶频繁项集列表中的首个项集\n",
    "# single_conseqs = [frozenset([item])\n",
    "#                   for item in freq_set]    # 所有可能的1阶后果，用于查询支持度\n",
    "# compute_conf(freq_set, single_conseqs, sup_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "同频繁项集的发现一样，规则也需要扩充，并且是从后果的角度来扩充。由上述函数可以得到所有$1$阶后果的强规则，那么如何得到$2$阶后果甚至更高阶后果的强规则？需要利用```extend_itemset```函数来对已有的后果列表进行升阶，然后在高阶后果列表中取出可能的后果加入已有后果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 因为这里使用递归手段，所以有些参数和变量需要从函数中独立出来\n",
    "\n",
    "\n",
    "def extend_rules(freq_set, conseqs, sup_dict, tol_rules, min_conf=0.5):\n",
    "    conseq_dim = len(conseqs[0])    # 后果的阶数\n",
    "\n",
    "    if len(freq_set) > conseq_dim+1:    # 只有2阶以上的频繁项集才有可能有高阶后果\n",
    "        high_level_conseqs = extend_itemset(conseqs)    # 后果升阶\n",
    "        cur_rules, high_level_conseqs = compute_conf(\n",
    "            freq_set, high_level_conseqs, sup_dict, min_conf)\n",
    "        tol_rules.extend(cur_rules)\n",
    "\n",
    "        if len(high_level_conseqs) > 1:\n",
    "            extend_rules(freq_set, high_level_conseqs, sup_dict,tol_rules, min_conf)\n",
    "\n",
    "\n",
    "# tol_rules = list()    # 记录所有规则的列表\n",
    "# freq_set = all_freq_sets[2][0]    # 3阶频繁项集列表中的首个项集\n",
    "# single_conseqs = [frozenset([item])\n",
    "#                   for item in freq_set]\n",
    "# extend_rules(freq_set, single_conseqs, sup_dict, tol_rules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tol_rules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来的任务就是由低阶往高阶遍历所有频繁项集列表，并发掘强规则。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_rules(all_freq_sets, sup_dict, min_conf=0.5):\n",
    "    tol_rules = list()\n",
    "    for i in range(1, len(all_freq_sets)):    # 跳过1阶频繁项集，因为1阶项集无法产生规则\n",
    "        for freq_set in all_freq_sets[i]:\n",
    "            conseqs = [frozenset([item]) for item in freq_set]    # 初始可能的1阶后果列表\n",
    "            if i > 1:\n",
    "                extend_rules(freq_set, conseqs, sup_dict, tol_rules, min_conf)\n",
    "            else:\n",
    "                cur_rules, _ = compute_conf(\n",
    "                    freq_set, conseqs, sup_dict, min_conf)\n",
    "                tol_rules.extend(cur_rules)\n",
    "\n",
    "    return tol_rules\n",
    "\n",
    "# find_rules(all_freq_sets,sup_dict,min_conf=0.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_freq_sets, sup_dict = apriori(dataset,min_sup=0.5)\n",
    "# find_rules(all_freq_sets,sup_dict,min_conf=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Toydataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(frozenset({'34'}), frozenset({'36'}), 0.834217841799343),\n",
       " (frozenset({'36'}), frozenset({'34'}), 0.9691720493247211),\n",
       " (frozenset({'34'}), frozenset({'85'}), 1.0),\n",
       " (frozenset({'85'}), frozenset({'34'}), 0.9741506646971935),\n",
       " (frozenset({'34'}), frozenset({'86'}), 0.9989891331817033),\n",
       " (frozenset({'86'}), frozenset({'34'}), 0.997728419989904),\n",
       " (frozenset({'34'}), frozenset({'90'}), 0.9219105382865808),\n",
       " (frozenset({'90'}), frozenset({'34'}), 0.9743589743589743),\n",
       " (frozenset({'36'}), frozenset({'85'}), 1.0),\n",
       " (frozenset({'85'}), frozenset({'36'}), 0.8385032003938946),\n",
       " (frozenset({'86'}), frozenset({'36'}), 0.8354366481574962),\n",
       " (frozenset({'36'}), frozenset({'86'}), 0.9718144450968879),\n",
       " (frozenset({'86'}), frozenset({'85'}), 1.0),\n",
       " (frozenset({'85'}), frozenset({'86'}), 0.9753815854258986),\n",
       " (frozenset({'90'}), frozenset({'85'}), 1.0),\n",
       " (frozenset({'85'}), frozenset({'90'}), 0.9217134416543574),\n",
       " (frozenset({'86'}), frozenset({'90'}), 0.9197375063099444),\n",
       " (frozenset({'90'}), frozenset({'86'}), 0.9732905982905983),\n",
       " (frozenset({'34'}), frozenset({'36', '86'}), 0.834217841799343),\n",
       " (frozenset({'86'}), frozenset({'34', '36'}), 0.8331650681474003),\n",
       " (frozenset({'36'}), frozenset({'34', '86'}), 0.9691720493247211),\n",
       " (frozenset({'34'}), frozenset({'36', '85'}), 0.834217841799343),\n",
       " (frozenset({'36'}), frozenset({'34', '85'}), 0.9691720493247211),\n",
       " (frozenset({'85'}), frozenset({'34', '36'}), 0.8126538650910882),\n",
       " (frozenset({'34'}), frozenset({'85', '86'}), 0.9989891331817033),\n",
       " (frozenset({'86'}), frozenset({'34', '85'}), 0.997728419989904),\n",
       " (frozenset({'85'}), frozenset({'34', '86'}), 0.9731659281142294),\n",
       " (frozenset({'34'}), frozenset({'85', '90'}), 0.9219105382865808),\n",
       " (frozenset({'90'}), frozenset({'34', '85'}), 0.9743589743589743),\n",
       " (frozenset({'85'}), frozenset({'34', '90'}), 0.8980797636632201),\n",
       " (frozenset({'34'}), frozenset({'86', '90'}), 0.9208996714682841),\n",
       " (frozenset({'86'}), frozenset({'34', '90'}), 0.9197375063099444),\n",
       " (frozenset({'90'}), frozenset({'34', '86'}), 0.9732905982905983),\n",
       " (frozenset({'86'}), frozenset({'36', '85'}), 0.8354366481574962),\n",
       " (frozenset({'36'}), frozenset({'85', '86'}), 0.9718144450968879),\n",
       " (frozenset({'85'}), frozenset({'36', '86'}), 0.8148695224027572),\n",
       " (frozenset({'86'}), frozenset({'85', '90'}), 0.9197375063099444),\n",
       " (frozenset({'90'}), frozenset({'85', '86'}), 0.9732905982905983),\n",
       " (frozenset({'85'}), frozenset({'86', '90'}), 0.897095027080256),\n",
       " (frozenset({'34', '86'}), frozenset({'36', '85'}), 0.8350619782443715),\n",
       " (frozenset({'36', '86'}), frozenset({'34', '85'}), 0.9972809667673717),\n",
       " (frozenset({'34', '36'}), frozenset({'85', '86'}), 1.0),\n",
       " (frozenset({'85', '86'}), frozenset({'34', '36'}), 0.8331650681474003),\n",
       " (frozenset({'34', '85'}), frozenset({'36', '86'}), 0.834217841799343),\n",
       " (frozenset({'36', '85'}), frozenset({'34', '86'}), 0.9691720493247211),\n",
       " (frozenset({'86'}), frozenset({'34', '36', '85'}), 0.8331650681474003),\n",
       " (frozenset({'34'}), frozenset({'36', '85', '86'}), 0.834217841799343),\n",
       " (frozenset({'36'}), frozenset({'34', '85', '86'}), 0.9691720493247211),\n",
       " (frozenset({'85'}), frozenset({'34', '36', '86'}), 0.8126538650910882),\n",
       " (frozenset({'34', '86'}), frozenset({'85', '90'}), 0.9218315203642803),\n",
       " (frozenset({'86', '90'}), frozenset({'34', '85'}), 1.0),\n",
       " (frozenset({'34', '90'}), frozenset({'85', '86'}), 0.9989035087719298),\n",
       " (frozenset({'85', '86'}), frozenset({'34', '90'}), 0.9197375063099444),\n",
       " (frozenset({'34', '85'}), frozenset({'86', '90'}), 0.9208996714682841),\n",
       " (frozenset({'85', '90'}), frozenset({'34', '86'}), 0.9732905982905983),\n",
       " (frozenset({'86'}), frozenset({'34', '85', '90'}), 0.9197375063099444),\n",
       " (frozenset({'34'}), frozenset({'85', '86', '90'}), 0.9208996714682841),\n",
       " (frozenset({'90'}), frozenset({'34', '85', '86'}), 0.9732905982905983),\n",
       " (frozenset({'85'}), frozenset({'34', '86', '90'}), 0.897095027080256)]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data=[line.split() for line in open('datasets/mushroom/agaricus-lepiota.data').readlines()]\n",
    "all_freq_sets, sup_dict = apriori(data,min_sup=0.8)\n",
    "find_rules(all_freq_sets,sup_dict,min_conf=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
